9a9c70
@@ -16,17 +16,16 @@
package org.springframework.data.neo4j.repository.support;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.List;
 import java.util.Optional;
 
 import org.neo4j.ogm.cypher.query.Pagination;
 import org.neo4j.ogm.session.Session;
 import org.springframework.data.domain.Page;
-import org.springframework.data.domain.PageImpl;
 import org.springframework.data.domain.Pageable;
 import org.springframework.data.domain.Sort;
 import org.springframework.data.neo4j.repository.Neo4jRepository;
 import org.springframework.data.neo4j.util.PagingAndSortingUtils;
+import org.springframework.data.repository.support.PageableExecutionUtils;
 import org.springframework.stereotype.Repository;
 import org.springframework.transaction.annotation.Transactional;
 import org.springframework.util.Assert;
@@ -36,12 +35,12 @@
import org.springframework.util.Assert;
  * you a more sophisticated interface than the plain {@link Session} .
  *
  * @param <T> the type of the entity to handle
- *
  * @author Vince Bickers
  * @author Luanne Misquitta
  * @author Mark Angrish
  * @author Mark Paluch
  * @author Jens Schauder
+ * @author Gerrit Meier
  */
 @Repository
 @Transactional(readOnly = true)
@@ -190,24 +189,16 @@
public class SimpleNeo4jRepository<T, ID extends Serializable> implements Neo4jR
 
 	@Override
 	public Page<T> findAll(Pageable pageable, int depth) {
-		Collection<T> data = session.loadAll(clazz, PagingAndSortingUtils.convert(pageable.getSort())
-				, new Pagination(pageable.getPageNumber(), pageable.getPageSize()), depth);
-		return updatePage(pageable, new ArrayList<>(data));
-	}
-
-	/*
-	 * This is a cheap trick to estimate the total number of objects without actually knowing the real value.
-	 * Essentially, if the result size is the same as the page size, we assume more data can be fetched, so
-	 * we set the expected total to the current total retrieved so far + the current page size. As soon as the
-	 * result size is less than the page size, we know there are no more, so we set the total to the number
-	 * retrieved so far. This will ensure that page.next() returns false.
-	 */
-	private Page<T> updatePage(Pageable pageable, List<T> results) {
-
-		int pageSize = pageable.getPageSize();
-		long pageOffset = pageable.getOffset();
-		long total = pageOffset + results.size() + (results.size() == pageSize ? pageSize : 0);
+		Collection<T> data;
+		if (pageable.isUnpaged()) {
+			// Pageable.unpaged() has no page number/size (its accessors throw), so load everything, still sorted.
+			data = session.loadAll(clazz, PagingAndSortingUtils.convert(pageable.getSort()), depth);
+		} else {
+			Pagination pagination = new Pagination(pageable.getPageNumber(), pageable.getPageSize());
+			data = session.loadAll(clazz, PagingAndSortingUtils.convert(pageable.getSort()), pagination, depth);
+		}
 
-		return new PageImpl<T>(results, pageable, total);
+		// getPage only runs the count query when the page content alone cannot determine the total.
+		return PageableExecutionUtils.getPage(new ArrayList<>(data), pageable, () -> session.countEntitiesOfType(clazz));
 	}
 }
